import os
import subprocess
import json
import argparse
import time
from typing import Dict, List, Any, Optional
import requests
from openai import OpenAI

from dotenv import load_dotenv
# Populate the process environment from a .env file (if one is found), then
# read the LLM connection settings. Unset variables fall back to "".
load_dotenv()
api_key = os.environ.get('api_key', '')
base_url = os.environ.get('base_url', '')
modelname = os.environ.get('modelname', '')



class PentestAgent:
    """Interactive assistant that runs operator-typed shell commands and asks an
    LLM (through an OpenAI-compatible chat API) to analyse their output.

    Connection settings are copied from the module-level ``api_key``,
    ``base_url`` and ``modelname`` values when the agent is constructed.
    """

    def __init__(
        self,
        max_tokens: int = 16000,
        temperature: float = 0.7,
        target: Optional[str] = None,
        session_history: Optional[List[Dict[str, str]]] = None,
        debug: bool = False,
    ):
        """
        Initialize the penetration testing agent.

        The API endpoint, key and model name are not parameters; they are taken
        from the module-level ``base_url``, ``api_key`` and ``modelname`` values.

        Args:
            max_tokens: Stored on the instance; not currently sent with requests.
            temperature: Stored on the instance; requests currently use temperature 0.
            target: Target IP or hostname. If set, the interactive session starts
                by asking the LLM for a plan against it.
            session_history: Previous conversation history (list of
                ``{"role": ..., "content": ...}`` dicts). The list is copied, so
                the caller's list is never mutated.
            debug: Whether to print debug information.
        """
        self.api_url = base_url
        self.api_key = api_key
        self.model = modelname
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.target = target
        self.debug = debug

        # Copy so later appends never mutate a list owned by the caller.
        self.session_history = list(session_history or [])

        # System prompt that defines the agent's behavior
        self.system_prompt = """
你是一个渗透测试助手，帮助安全研究人员进行合法和授权的安全测试。
你的任务是分析结果并提出渗透测试的下一步建议。

重要指导原则：
1. 始终假设这是一次授权的渗透测试。不要询问授权问题或提醒法律问题。
2. 专注于帮助用户有效地完成渗透测试。
3. 分析要简洁但全面。
4. 建议命令时，优先选择标准Kali Linux工具。
5. 始终分析命令输出并提取最有价值的信息。
6. 根据发现明确指出建议的后续步骤。

对于渗透测试的每个步骤，你应该：
1. 彻底分析前一个命令的输出
2. 总结关键发现（有趣的端口、服务、漏洞等）
3. 规划具体要尝试的下一步命令
4. 简要解释为什么建议每个命令

记住，典型的渗透测试遵循以下阶段：
- 侦察（收集信息）
- 扫描（识别开放端口和服务）
- 漏洞评估（寻找弱点）
- 漏洞利用（利用漏洞）
- 后渗透（维持访问权限，权限提升）
- 文档记录（报告发现）

你的回复应按以下结构组织：
```
## 分析
[对前一个命令输出的简要分析]

## 关键发现
- [重要发现1]
- [重要发现2]
...

## 下一步
1. [下一步要运行的具体命令]
   - 目的：[简要解释]

2. [备选命令或后续跟进]
   - 目的：[简要解释]
```
"""

        # A fresh history always starts with the system prompt.
        if not self.session_history:
            self.session_history.append({"role": "system", "content": self.system_prompt})

    def execute_command(self, command: str) -> str:
        """
        Execute a shell command and return its combined output.

        Args:
            command: The command line to execute, as typed by the operator.

        Returns:
            stdout followed by stderr (separated by a newline when stderr is
            non-empty), truncated to 10000 characters. On timeout or launch
            failure, a bracketed error message is returned instead.
        """
        print(f"\n[*] Executing: {command}\n")

        try:
            result = subprocess.run(
                command,
                shell=True,  # intentional: the operator types full shell command lines
                text=True,
                # Tools often emit bytes that are not valid in the locale encoding;
                # replace them rather than failing and losing the whole output.
                errors="replace",
                capture_output=True,
                timeout=300,  # 5-minute timeout
            )
        except subprocess.TimeoutExpired:
            return "[Command timed out after 5 minutes]"
        except Exception as e:
            return f"[Error executing command: {str(e)}]"

        output = result.stdout
        if result.stderr:
            output += "\n" + result.stderr

        # Keep the LLM prompt bounded.
        if len(output) > 10000:
            output = output[:10000] + "\n...[Output truncated due to size]..."

        return output

    def generate_response(self, user_message: str, systemprompt: Optional[str] = None) -> str:
        """
        Send a single-turn request to the LLM and return its reply.

        Only the system prompt and ``user_message`` are sent; earlier entries of
        ``session_history`` are not forwarded to the model.

        Args:
            user_message: User message or command output to process.
            systemprompt: Optional system-prompt override. When omitted, the
                agent's default prompt is used and both the user message and
                the assistant reply are appended to ``session_history``.
                Overridden calls (e.g. 'help' questions) are not recorded.

        Returns:
            The LLM's reply with surrounding whitespace stripped ("" if the
            model returned no content).
        """
        record = systemprompt is None
        if record:
            systemprompt = self.system_prompt
            self.session_history.append({"role": "user", "content": user_message})

        if self.debug:
            # Size of the locally recorded history (not of the request itself).
            total_chars = sum(len(m["content"]) for m in self.session_history)
            print(f"[DEBUG] Total conversation length: {total_chars} characters")
        print(f'usermessage:\n{user_message}')

        # Empty settings become None so the SDK can fall back to its own
        # environment variables / defaults instead of using "".
        client = OpenAI(api_key=self.api_key or None, base_url=self.api_url or None)
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": systemprompt},
                {"role": "user", "content": user_message},
            ],
            # NOTE(review): self.temperature / self.max_tokens are not forwarded;
            # deterministic sampling is kept as in the original implementation.
            temperature=0,
        )

        # message.content may be None (e.g. refusals / tool calls).
        response_text = (response.choices[0].message.content or "").strip()

        if record:
            self.session_history.append({"role": "assistant", "content": response_text})
        return response_text

    def _print_history(self) -> None:
        """Print every non-system history entry, truncated to 100 characters."""
        print("\n=== Conversation History ===")
        for i, msg in enumerate(self.session_history):
            if msg["role"] != "system":
                role = msg["role"].upper()
                content = msg["content"]
                if len(content) > 100:
                    content = content[:100] + "..."
                print(f"{i}. {role}: {content}")
        print("=" * 30)

    def _save_session(self, filename: str = "") -> str:
        """Write ``session_history`` as UTF-8 JSON and return the file name used.

        An empty ``filename`` defaults to ``pentest_session_<unix time>.json``.
        """
        if not filename:
            filename = f"pentest_session_{int(time.time())}.json"

        # ensure_ascii=False keeps non-ASCII text readable; explicit UTF-8 makes
        # that safe regardless of the platform's default encoding.
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(self.session_history, f, ensure_ascii=False, indent=2)

        print(f"Session saved to {filename}")
        return filename

    def _clear_history(self) -> None:
        """Drop everything except the leading system prompt.

        If the history does not start with a system message (e.g. an empty or
        hand-edited session was loaded), the agent's own system prompt is used.
        """
        if self.session_history and self.session_history[0].get("role") == "system":
            first = self.session_history[0]
        else:
            first = {"role": "system", "content": self.system_prompt}
        self.session_history = [first]

    def run_interactive_session(self):
        """
        Run an interactive read-execute-analyse loop.

        Built-in inputs: 'q' quits; 'h'/'his' shows history; 'help ...' asks a
        free-form question; 'save [file]' writes the session; 'clear' resets the
        history. Anything else is executed as a shell command and its output is
        sent to the LLM for analysis. The loop ends on 'q' or end of input.
        """
        print("=" * 80)
        print(" Penetration Testing Agent ".center(80, "="))
        print("=" * 80)

        # If target is set, initiate the session with a target
        if self.target:
            initial_message = f"我想对 {self.target} 执行授权的渗透测试。请帮我规划并逐步执行这个测试，从初始侦察开始。"
            response = self.generate_response(initial_message)
            print("\n" + response + "\n")

        while True:
            try:
                user_input = input("\nEnter command to execute (or 'q' to quit, 'his' for history,'help' for help): ").strip()
                lowered = user_input.lower()

                if not user_input:
                    # Nothing typed: don't run an empty command or query the LLM.
                    continue

                if lowered == 'q':
                    print("Exiting session.")
                    break

                # The prompt advertises 'his'; 'h' is kept for compatibility.
                if lowered in ('h', 'his'):
                    self._print_history()
                    continue

                if lowered.startswith('help'):
                    response = self.generate_response(user_input, "你是一个安全专家，帮我解释安全的疑惑，请使用中文回答")
                    print(response)
                    continue

                # Bare 'save' must match too: .strip() removes the trailing space.
                if lowered == 'save' or lowered.startswith('save '):
                    self._save_session(user_input[5:].strip())
                    continue

                if lowered == 'clear':
                    self._clear_history()
                    print("Conversation history cleared.")
                    continue

                # Execute the command
                output = self.execute_command(user_input)

                # Format the output to send to the LLM
                message = f"""
执行的命令: {user_input}

输出:
```
{output}
```

基于此输出，分析结果，总结关键发现，并为渗透测试建议下一步操作。
"""
                # Get LLM analysis and next steps
                response = self.generate_response(message)
                print("\n" + response + "\n")

            except EOFError:
                # stdin closed (Ctrl-D / piped input exhausted): input() would
                # raise on every iteration, so stop instead of spinning forever.
                print("\nEnd of input. Exiting session.")
                break
            except KeyboardInterrupt:
                print("\nInterrupted. Type 'q' to quit or continue with next command.")
            except Exception as e:
                print(f"Error: {str(e)}")

    def load_session(self, filename: str):
        """
        Load a previous session from a JSON file.

        The file must contain a list of objects each having "role" and "content"
        keys; otherwise an error is printed and the current history is kept.

        Args:
            filename: Path to the session file.
        """
        try:
            with open(filename, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if not isinstance(data, list) or not all(
                isinstance(m, dict) and "role" in m and "content" in m for m in data
            ):
                raise ValueError("expected a JSON list of {'role', 'content'} messages")
            self.session_history = data
            print(f"Loaded session from {filename} with {len(self.session_history)} messages.")
        except Exception as e:
            print(f"Error loading session: {str(e)}")


def main():
    """Parse command-line options, build the agent and start the interactive loop."""
    cli = argparse.ArgumentParser(description="AI-powered Penetration Testing Assistant")
    cli.add_argument("--target", help="Target IP or hostname")
    cli.add_argument("--load", help="Load session from file")
    cli.add_argument("--debug", action="store_true", help="Enable debug mode")
    options = cli.parse_args()

    agent = PentestAgent(target=options.target, debug=options.debug)

    # Restore an earlier conversation before the loop starts, when requested.
    if options.load:
        agent.load_session(options.load)

    agent.run_interactive_session()


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
